{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true,
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "e9038992d0a94189aea25e7aafd0d955": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_df5abe1d8cc54d809a7ca71c2b5cd72a",
              "IPY_MODEL_c3e956c9c95d4ef09c46c1d1d2aee6ae",
              "IPY_MODEL_0b814c5163ef48f7bd726ae215ba3ccb"
            ],
            "layout": "IPY_MODEL_6a5dd5f8b5ea4dac921ca3e5c0d6c7fe"
          }
        },
        "df5abe1d8cc54d809a7ca71c2b5cd72a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e481024b210446c9bfe6dd01c222badc",
            "placeholder": "​",
            "style": "IPY_MODEL_adcb43e6af794301af2d293dee02cfae",
            "value": "tokenizer_config.json: 100%"
          }
        },
        "c3e956c9c95d4ef09c46c1d1d2aee6ae": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e49e945c84da44e1a9e83c1b062895f2",
            "max": 49,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c66fb5047a714baea3b60da040d4885d",
            "value": 49
          }
        },
        "0b814c5163ef48f7bd726ae215ba3ccb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4f2c2f387bd34b7a946a924d28d06512",
            "placeholder": "​",
            "style": "IPY_MODEL_9e8e9024a0e84af78fc232994ac9603b",
            "value": " 49.0/49.0 [00:00&lt;00:00, 6.19kB/s]"
          }
        },
        "6a5dd5f8b5ea4dac921ca3e5c0d6c7fe": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e481024b210446c9bfe6dd01c222badc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "adcb43e6af794301af2d293dee02cfae": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e49e945c84da44e1a9e83c1b062895f2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c66fb5047a714baea3b60da040d4885d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "4f2c2f387bd34b7a946a924d28d06512": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9e8e9024a0e84af78fc232994ac9603b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c58dcf3b09564eefb270ca54d38b1ce3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_69c264f9c97f49b789e7402f4425ca82",
              "IPY_MODEL_84bd07ee3e6e4c969d5953d8230e2191",
              "IPY_MODEL_cf19a9a733b141ed9d7647145c17bf82"
            ],
            "layout": "IPY_MODEL_221af6eaafb341f3b629b6058b6725d3"
          }
        },
        "69c264f9c97f49b789e7402f4425ca82": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7b50eeeb182641779220739cb1a3911c",
            "placeholder": "​",
            "style": "IPY_MODEL_77bfee349ff14975a3aee2fbc0999399",
            "value": "vocab.txt: 100%"
          }
        },
        "84bd07ee3e6e4c969d5953d8230e2191": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c931f85d242b43ef9e1a0114e8e4b507",
            "max": 254728,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_236e7cc9c317461ca072d3b7292861de",
            "value": 254728
          }
        },
        "cf19a9a733b141ed9d7647145c17bf82": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4837a31c4b8d409fa4e454b1b5ab32de",
            "placeholder": "​",
            "style": "IPY_MODEL_8b55796b7a474123aadd27d5175e245f",
            "value": " 255k/255k [00:00&lt;00:00, 589kB/s]"
          }
        },
        "221af6eaafb341f3b629b6058b6725d3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7b50eeeb182641779220739cb1a3911c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "77bfee349ff14975a3aee2fbc0999399": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c931f85d242b43ef9e1a0114e8e4b507": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "236e7cc9c317461ca072d3b7292861de": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "4837a31c4b8d409fa4e454b1b5ab32de": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8b55796b7a474123aadd27d5175e245f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "53a3f5de9ec3420d99649bd080580bcc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8a0cbccd477e4d729f6a231492480a1b",
              "IPY_MODEL_7287055921464efda101a065562fda70",
              "IPY_MODEL_22f8bf8d9d1c4dfda3c5095dd80fd5d6"
            ],
            "layout": "IPY_MODEL_4ee58848744f4ffc99799d85a7290786"
          }
        },
        "8a0cbccd477e4d729f6a231492480a1b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_529b644ec0e245cb80b88e3f0401f590",
            "placeholder": "​",
            "style": "IPY_MODEL_72bf61c3726447229cc207db31a2d3cf",
            "value": "tokenizer.json: 100%"
          }
        },
        "7287055921464efda101a065562fda70": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b8ebcb0c9f2844c6aa6995da80343167",
            "max": 485115,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_6e44050baff942b9a72d61d8ed3a581d",
            "value": 485115
          }
        },
        "22f8bf8d9d1c4dfda3c5095dd80fd5d6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7997aebdc5a244ceb69bba6f91d2a53d",
            "placeholder": "​",
            "style": "IPY_MODEL_d17f4c4270944659b369bcbea60a1ccb",
            "value": " 485k/485k [00:00&lt;00:00, 13.8MB/s]"
          }
        },
        "4ee58848744f4ffc99799d85a7290786": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "529b644ec0e245cb80b88e3f0401f590": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "72bf61c3726447229cc207db31a2d3cf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b8ebcb0c9f2844c6aa6995da80343167": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6e44050baff942b9a72d61d8ed3a581d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7997aebdc5a244ceb69bba6f91d2a53d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d17f4c4270944659b369bcbea60a1ccb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0c1e20e2145d4281b633472ed9457467": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b05effd9c89744f081bd286257b6a874",
              "IPY_MODEL_09d388cd36b74b54a6652a647b1f5951",
              "IPY_MODEL_78f0e7a863254ed98245fa498c78c838"
            ],
            "layout": "IPY_MODEL_52ad3d6aa2594f9491beba5e0ecbd3ad"
          }
        },
        "b05effd9c89744f081bd286257b6a874": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_090b34941f9541d7b50d08fa254774eb",
            "placeholder": "​",
            "style": "IPY_MODEL_b58a0bc45a21441f9c949e9f71b91c6c",
            "value": "config.json: 100%"
          }
        },
        "09d388cd36b74b54a6652a647b1f5951": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_51302eec9fc54c4a8d603c2903ebcbc3",
            "max": 433,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5e892e68b5f84ca8bd9d28b5282674c8",
            "value": 433
          }
        },
        "78f0e7a863254ed98245fa498c78c838": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f808a1633ee14b47b2013fff4b9e45b8",
            "placeholder": "​",
            "style": "IPY_MODEL_2c51ba3cfe13470e826182ea5e15c5b6",
            "value": " 433/433 [00:00&lt;00:00, 55.3kB/s]"
          }
        },
        "52ad3d6aa2594f9491beba5e0ecbd3ad": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "090b34941f9541d7b50d08fa254774eb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b58a0bc45a21441f9c949e9f71b91c6c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "51302eec9fc54c4a8d603c2903ebcbc3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5e892e68b5f84ca8bd9d28b5282674c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "f808a1633ee14b47b2013fff4b9e45b8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2c51ba3cfe13470e826182ea5e15c5b6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "891331f6d3fc44838aac504c0d0a09b0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b0c7334da03a4d019697751e92e3b6d2",
              "IPY_MODEL_50e6edc971714e479c5ed4bde7f0e97b",
              "IPY_MODEL_a85a5b3d241945ffaaadabc516c0d262"
            ],
            "layout": "IPY_MODEL_a839295991cb428ab9a0b01bca4cd188"
          }
        },
        "b0c7334da03a4d019697751e92e3b6d2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_df4272233fd64460a5ef2cd526835e77",
            "placeholder": "​",
            "style": "IPY_MODEL_724ee2ed34454891b235f53279abef49",
            "value": "model.safetensors: 100%"
          }
        },
        "50e6edc971714e479c5ed4bde7f0e97b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c907e791e4dc495aa1aeab3c8aeafd5f",
            "max": 438844124,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_1d475887cba24d34ad44d02eb64f86bc",
            "value": 438844124
          }
        },
        "a85a5b3d241945ffaaadabc516c0d262": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_269564898d6b4c37acbf0ab9f3b3a07c",
            "placeholder": "​",
            "style": "IPY_MODEL_75c308037d4c405aaffe888fe5c82c43",
            "value": " 439M/439M [00:02&lt;00:00, 291MB/s]"
          }
        },
        "a839295991cb428ab9a0b01bca4cd188": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "df4272233fd64460a5ef2cd526835e77": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "724ee2ed34454891b235f53279abef49": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c907e791e4dc495aa1aeab3c8aeafd5f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1d475887cba24d34ad44d02eb64f86bc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "269564898d6b4c37acbf0ab9f3b3a07c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "75c308037d4c405aaffe888fe5c82c43": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "134ce7ce58b245ffbb6341120c321f72": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_106d65461c3e40a08732d506b0b52bcd",
              "IPY_MODEL_3b0645d3b94044c5a25490e89d177326",
              "IPY_MODEL_36f31ddf7b93406d9da76c206fecf812"
            ],
            "layout": "IPY_MODEL_13d3d1305df442f89acc84af217ebc26"
          }
        },
        "106d65461c3e40a08732d506b0b52bcd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5955b8362ff64eb89aaff3e998c06364",
            "placeholder": "​",
            "style": "IPY_MODEL_24c80a73020d469aad49b1979ac05b62",
            "value": "Loading weights: 100%"
          }
        },
        "3b0645d3b94044c5a25490e89d177326": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8ee757eb7908481f85ff5c536fc8a214",
            "max": 199,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_97dc681f605d48a58ec498f6042c3aab",
            "value": 199
          }
        },
        "36f31ddf7b93406d9da76c206fecf812": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_61474933d4564ae4ba82b7b7c9f3bc76",
            "placeholder": "​",
            "style": "IPY_MODEL_f423bb07cbc9412ba7af675097173edc",
            "value": " 199/199 [00:00&lt;00:00, 1039.20it/s, Materializing param=pooler.dense.weight]"
          }
        },
        "13d3d1305df442f89acc84af217ebc26": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5955b8362ff64eb89aaff3e998c06364": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "24c80a73020d469aad49b1979ac05b62": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8ee757eb7908481f85ff5c536fc8a214": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "97dc681f605d48a58ec498f6042c3aab": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "61474933d4564ae4ba82b7b7c9f3bc76": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f423bb07cbc9412ba7af675097173edc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e507ccc6226844d581337d7b1376bc67": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_08d8f19593c3491f8d4a2675e9712c9a",
              "IPY_MODEL_1f6085c27430495daa40d2a0bd0f03af",
              "IPY_MODEL_c991712fbc484f3ba2f1ed943d301ef6"
            ],
            "layout": "IPY_MODEL_875d4f5d37fe4a5d883abed1995a78e7"
          }
        },
        "08d8f19593c3491f8d4a2675e9712c9a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8a1bfb11b19143f3bfc9a2b85f9f96e7",
            "placeholder": "​",
            "style": "IPY_MODEL_7e57800838e54f87851f292fc95c1565",
            "value": "Loading weights: 100%"
          }
        },
        "1f6085c27430495daa40d2a0bd0f03af": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6a32bc8713af4b4ba506c30fe4e1dd8b",
            "max": 199,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_bd653ceffd3c406f97d10fd6b39a8b63",
            "value": 199
          }
        },
        "c991712fbc484f3ba2f1ed943d301ef6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_32bcfcb821f344ee8543157a68853029",
            "placeholder": "​",
            "style": "IPY_MODEL_ea6770024da94054b766fd4aa62c801b",
            "value": " 199/199 [00:00&lt;00:00, 768.55it/s, Materializing param=pooler.dense.weight]"
          }
        },
        "875d4f5d37fe4a5d883abed1995a78e7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8a1bfb11b19143f3bfc9a2b85f9f96e7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7e57800838e54f87851f292fc95c1565": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6a32bc8713af4b4ba506c30fe4e1dd8b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bd653ceffd3c406f97d10fd6b39a8b63": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "32bcfcb821f344ee8543157a68853029": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ea6770024da94054b766fd4aa62c801b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# **Reproduction of Group Classification**\n",
        "\n",
        "The reproduction of the group classification involves training large language models, which ideally requieres access to GPUs- These models have been trained with A100 GPUS.\n",
        "\n",
        "The reproduction includes to series of training, with an interim training data revision stream that involved human re-classification based on class-propabilities from the first round of training. Furthermore, synthetic examples for income thresholds have been added to teh training data of round 2 as described in teh Appendix of the Technical Report.\n",
        "\n",
        "**Data requirements**\n",
        "\n",
        "First training series:\n",
        "\n",
        "- Training data: all_sentences_df_clean_kw_checked.xlsx\n",
        "\n",
        "The model from first series can be used for teh first round prediction:\n",
        "\n",
        "  - Folder: bert_multilabel_group_model\n",
        "\n",
        "Second training series:\n",
        "\n",
        "- Training data: all_sentences_df_clean_kw_checked_added.xlsx\n",
        "\n",
        "The model from second series can be used for the final round prediction:\n",
        "\n",
        "- Folder: bert_multilabel_group_model_2\n",
        "\n",
        "Prediction:\n",
        "\n",
        "- Parties: parties_df_group_labels.xlsx\n",
        "- Media: media_df_group_label.xlsx\n",
        "\n",
        "Final revised prediction used for graphs in the Technical Report:\n",
        "Here we added a final revision round looking manually add the labels and cleaning the metadata.\n",
        "\n",
        "- Media: media_ml_predictions_with_metadata_revised.xlsx\n",
        "- Parties: ml_predictions_metadata.xlsx"
      ],
      "metadata": {
        "id": "vqgn7Ba3ubhP"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zkv9pk7kRPoa",
        "outputId": "1cb2ca11-56ca-4e2d-efdc-4394e388f343"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 1.First round sentence level training of bert-base-german-cased model"
      ],
      "metadata": {
        "id": "R0I74W055_Fs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "group_training_df = pd.read_excel('/content/drive/all_sentences_df_clean_kw_checked.xlsx')"
      ],
      "metadata": {
        "id": "QFW6lSysRtZH"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1.1 Extracting the labels"
      ],
      "metadata": {
        "id": "7b_0Xx7r6Mvv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "import re\n",
        "\n",
        "\n",
        "# Function to extract all annotations like [G_l_m_neu]\n",
        "def extract_tags(text):\n",
        "    if isinstance(text, str): # Check if the text is a string\n",
        "        return re.findall(r'\\[G_([^\\]]+)\\]', text)\n",
        "    else:\n",
        "        return [] # Return an empty list for non-string values\n",
        "\n",
        "# Apply to create new column with list of annotation components\n",
        "group_training_df['tags'] = group_training_df['sentence'].apply(extract_tags)"
      ],
      "metadata": {
        "id": "jOk32u61RwKV"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Codebook mappings\n",
        "dimension_codes = {\n",
        "    'b': 'Ökonomisch - Bedürftig',\n",
        "    'w': 'Ökonomisch - Wohlhabend',\n",
        "    'm': 'Eigentümer - Mieter',\n",
        "    'v': 'Eigentümer - Vermieter/Eigentümer',\n",
        "    'e': 'Wohnstruktur - Einparteienhaus',\n",
        "    'mp': 'Wohnstruktur - Mehrparteienhaus',\n",
        "    'l': 'Geographie - Land',\n",
        "    's': 'Geographie - Stadt',\n",
        "    'o': 'Other - All other'\n",
        "}\n",
        "\n",
        "appeal_codes = {\n",
        "    'pos': 'positiv',\n",
        "    'neu': 'neutral',\n",
        "    'neg': 'negativ'\n",
        "}\n",
        "\n",
        "# Function to classify each tag\n",
        "def classify_tag_components(tag_str):\n",
        "    parts = tag_str.split('_')\n",
        "    result = {'dimensions': [], 'appeal': None}\n",
        "    for part in parts:\n",
        "        if part in appeal_codes:\n",
        "            result['appeal'] = appeal_codes[part]\n",
        "        elif part in dimension_codes:\n",
        "            result['dimensions'].append(dimension_codes[part])\n",
        "        else:\n",
        "            # For compound codes like 'mp' (Mehrparteienhaus)\n",
        "            compound_matches = [code for code in dimension_codes if code in part]\n",
        "            for match in compound_matches:\n",
        "                if match == part:\n",
        "                    result['dimensions'].append(dimension_codes[match])\n",
        "    return result\n",
        "\n",
        "# Expand all tag components into dimensions and appeal per row\n",
        "def analyze_tags(tag_list):\n",
        "    results = []\n",
        "    for tag in tag_list:\n",
        "        results.append(classify_tag_components(tag))\n",
        "    return results\n",
        "\n",
        "group_training_df['tag_details'] = group_training_df['tags'].apply(analyze_tags)"
      ],
      "metadata": {
        "id": "OERDRnHgRya7"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from collections import Counter\n",
        "\n",
        "all_dimensions = group_training_df['tag_details'].explode().dropna().apply(lambda x: x['dimensions'])\n",
        "flat_dimensions = [dim for sublist in all_dimensions.dropna() for dim in sublist]\n",
        "dimension_counts = Counter(flat_dimensions)\n",
        "print(dimension_counts)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Dr5tc7uER02v",
        "outputId": "5e9e5499-ac9a-4c9a-c210-ed981ff7a43b"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Counter({'Other - All other': 785, 'Eigentümer - Vermieter/Eigentümer': 400, 'Eigentümer - Mieter': 233, 'Ökonomisch - Bedürftig': 223, 'Ökonomisch - Wohlhabend': 62, 'Wohnstruktur - Mehrparteienhaus': 9, 'Geographie - Land': 8, 'Geographie - Stadt': 7, 'Wohnstruktur - Einparteienhaus': 6})\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "all_appeals = group_training_df['tag_details'].explode().dropna().apply(lambda x: x['appeal'])\n",
        "appeal_counts = Counter(all_appeals)\n",
        "print(appeal_counts)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lqkr3R6vR2rH",
        "outputId": "65403cc9-ab7e-483d-aa36-d84b6efcd2ae"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Counter({'positiv': 958, 'neutral': 563, 'negativ': 111, None: 25})\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Create exploded version for analysis\n",
        "rows = []\n",
        "for idx, tag_list in group_training_df['tag_details'].items():\n",
        "    for tag in tag_list:\n",
        "        for dim in tag['dimensions']:\n",
        "            rows.append({\n",
        "                'article_id': idx,\n",
        "                'dimension': dim,\n",
        "                'appeal': tag['appeal']\n",
        "            })\n",
        "\n",
        "analysis_df = pd.DataFrame(rows)"
      ],
      "metadata": {
        "id": "xerL1R1oR4jq"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Frequency table\n",
        "pd.crosstab(analysis_df['dimension'], analysis_df['appeal'])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 363
        },
        "id": "Goww0Ex0R6lc",
        "outputId": "a78d252b-ec11-4be1-9548-cf0d151ac0c5"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "appeal                             negativ  neutral  positiv\n",
              "dimension                                                   \n",
              "Eigentümer - Mieter                      2       64      167\n",
              "Eigentümer - Vermieter/Eigentümer       19      182      199\n",
              "Geographie - Land                        0        1        7\n",
              "Geographie - Stadt                       0        5        2\n",
              "Other - All other                       32      283      447\n",
              "Wohnstruktur - Einparteienhaus           1        5        0\n",
              "Wohnstruktur - Mehrparteienhaus          0        5        4\n",
              "Ökonomisch - Bedürftig                  11       34      177\n",
              "Ökonomisch - Wohlhabend                 47       10        5"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-63b8b6a5-b0a4-46db-900e-002cc7ae7ae6\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th>appeal</th>\n",
              "      <th>negativ</th>\n",
              "      <th>neutral</th>\n",
              "      <th>positiv</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>dimension</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>Eigentümer - Mieter</th>\n",
              "      <td>2</td>\n",
              "      <td>64</td>\n",
              "      <td>167</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Eigentümer - Vermieter/Eigentümer</th>\n",
              "      <td>19</td>\n",
              "      <td>182</td>\n",
              "      <td>199</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Geographie - Land</th>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>7</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Geographie - Stadt</th>\n",
              "      <td>0</td>\n",
              "      <td>5</td>\n",
              "      <td>2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Other - All other</th>\n",
              "      <td>32</td>\n",
              "      <td>283</td>\n",
              "      <td>447</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Wohnstruktur - Einparteienhaus</th>\n",
              "      <td>1</td>\n",
              "      <td>5</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Wohnstruktur - Mehrparteienhaus</th>\n",
              "      <td>0</td>\n",
              "      <td>5</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Ökonomisch - Bedürftig</th>\n",
              "      <td>11</td>\n",
              "      <td>34</td>\n",
              "      <td>177</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Ökonomisch - Wohlhabend</th>\n",
              "      <td>47</td>\n",
              "      <td>10</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-63b8b6a5-b0a4-46db-900e-002cc7ae7ae6')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-63b8b6a5-b0a4-46db-900e-002cc7ae7ae6 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-63b8b6a5-b0a4-46db-900e-002cc7ae7ae6');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "summary": "{\n  \"name\": \"pd\",\n  \"rows\": 9,\n  \"fields\": [\n    {\n      \"column\": \"dimension\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 9,\n        \"samples\": [\n          \"\\u00d6konomisch - Bed\\u00fcrftig\",\n          \"Eigent\\u00fcmer - Vermieter/Eigent\\u00fcmer\",\n          \"Wohnstruktur - Einparteienhaus\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"negativ\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 17,\n        \"min\": 0,\n        \"max\": 47,\n        \"num_unique_values\": 7,\n        \"samples\": [\n          2,\n          19,\n          11\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"neutral\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 100,\n        \"min\": 1,\n        \"max\": 283,\n        \"num_unique_values\": 7,\n        \"samples\": [\n          64,\n          182,\n          34\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"positiv\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 152,\n        \"min\": 0,\n        \"max\": 447,\n        \"num_unique_values\": 9,\n        \"samples\": [\n          177,\n          199,\n          0\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Keep the dimension x appeal count table for the percentage cells that follow\n",
        "freq_table = pd.crosstab(index=analysis_df['dimension'], columns=analysis_df['appeal'])"
      ],
      "metadata": {
        "id": "fpdjB35K1QCn"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Row shares: each dimension's appeal counts as a percentage of that dimension's total.\n",
        "# NOTE: expects freq_table to hold appeal counts only; a later cell rebuilds freq_table with\n",
        "# 'Total' and '% of all mentions' columns, so do not re-run this cell after that one.\n",
        "percentage_table = freq_table.div(freq_table.sum(axis='columns'), axis='index').mul(100)"
      ],
      "metadata": {
        "id": "7_vJgFBG1SNw"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Column shares: each appeal's counts as a percentage of that appeal's total across dimensions.\n",
        "percentage_table_col = freq_table.div(freq_table.sum(axis='index'), axis='columns').mul(100)"
      ],
      "metadata": {
        "id": "eaK_hcLG1UXX"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Share of each dimension among all rows of analysis_df, in percent.\n",
        "# NOTE(review): rows with a missing appeal are counted here but appear to be absent from the\n",
        "# crosstab (recorded outputs show 785 vs 762 'Other - All other' mentions) - confirm which base is wanted.\n",
        "dimension_counts = analysis_df['dimension'].value_counts(normalize=True).mul(100)"
      ],
      "metadata": {
        "id": "fYMmsWO_1Wa1"
      },
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Copy of the count table extended with per-dimension totals and their (unrounded) percentage share.\n",
        "summary = freq_table.assign(Total=freq_table.sum(axis=1))\n",
        "summary['% of all mentions'] = summary['Total'].div(summary['Total'].sum()).mul(100)"
      ],
      "metadata": {
        "id": "G9-d6Ifj1Ym9"
      },
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The frequencies derived from the following code are used in Appendix part A.3.1 to describe which group appeals are relevant. Note: They do not reflect the numbers in Table A.3.3, which is based on the final revised training data."
      ],
      "metadata": {
        "id": "0csHNCEi3zmL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Dimension x appeal counts, rebuilt from analysis_df so this cell starts from a clean table\n",
        "freq_table = pd.crosstab(index=analysis_df['dimension'], columns=analysis_df['appeal'])\n",
        "\n",
        "# Mentions per dimension, summed over the appeal columns\n",
        "freq_table['Total'] = freq_table.sum(axis='columns')\n",
        "\n",
        "# Each dimension's share of all mentions, in percent, rounded to two decimals\n",
        "freq_table['% of all mentions'] = (\n",
        "    freq_table['Total'] / freq_table['Total'].sum() * 100\n",
        ").round(2)\n",
        "\n",
        "# Show the finished table\n",
        "print(freq_table)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TBDEJUOw1a-Z",
        "outputId": "1b53bd97-af5e-4833-964a-8bc86cc9fc30"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "appeal                             negativ  neutral  positiv  Total  \\\n",
            "dimension                                                             \n",
            "Eigentümer - Mieter                      2       64      167    233   \n",
            "Eigentümer - Vermieter/Eigentümer       19      182      199    400   \n",
            "Geographie - Land                        0        1        7      8   \n",
            "Geographie - Stadt                       0        5        2      7   \n",
            "Other - All other                       32      283      447    762   \n",
            "Wohnstruktur - Einparteienhaus           1        5        0      6   \n",
            "Wohnstruktur - Mehrparteienhaus          0        5        4      9   \n",
            "Ökonomisch - Bedürftig                  11       34      177    222   \n",
            "Ökonomisch - Wohlhabend                 47       10        5     62   \n",
            "\n",
            "appeal                             % of all mentions  \n",
            "dimension                                             \n",
            "Eigentümer - Mieter                            13.63  \n",
            "Eigentümer - Vermieter/Eigentümer              23.41  \n",
            "Geographie - Land                               0.47  \n",
            "Geographie - Stadt                              0.41  \n",
            "Other - All other                              44.59  \n",
            "Wohnstruktur - Einparteienhaus                  0.35  \n",
            "Wohnstruktur - Mehrparteienhaus                 0.53  \n",
            "Ökonomisch - Bedürftig                         12.99  \n",
            "Ökonomisch - Wohlhabend                         3.63  \n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Extract sentence-level target labels: create binary label columns (or a multi-hot vector) for each of the four target dimensions."
      ],
      "metadata": {
        "id": "cPnM-ei6SN74"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Annotation code -> display name of the four target groups\n",
        "label_map = dict(\n",
        "    b='Bedürftig',\n",
        "    w='Wohlhabend',\n",
        "    m='Mieter',\n",
        "    v='Vermieter/Eigentümer',\n",
        ")\n",
        "\n",
        "# Names in label_map order; used below as the label column names\n",
        "target_labels = [*label_map.values()]\n",
        "\n",
        "# Function to extract the relevant dimensions (b, w, m, v)\n",
        "def extract_target_dimensions(tag_details):\n",
        "    \"\"\"Return the target-group names mentioned in one row's tag details.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    tag_details : iterable of dict\n",
        "        Each dict must provide a 'dimensions' iterable of hashable dimension values.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    list of str\n",
        "        The label_map names whose dimension_codes entry occurs in any tag, each at\n",
        "        most once and always in label_map order.\n",
        "    \"\"\"\n",
        "    # Collect every dimension mentioned in this row exactly once.\n",
        "    mentioned = {dim for detail in tag_details for dim in detail['dimensions']}\n",
        "    # Walk label_map instead of returning list(set): the iteration order of a set of strings\n",
        "    # varies between kernel sessions (hash randomisation), which made sentence_labels irreproducible.\n",
        "    return [name for code, name in label_map.items() if dimension_codes[code] in mentioned]\n",
        "\n",
        "# One list of target-group names per row\n",
        "group_training_df['sentence_labels'] = group_training_df['tag_details'].map(extract_target_dimensions)"
      ],
      "metadata": {
        "id": "IGWO53dQSM4Q"
      },
      "execution_count": 21,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# One boolean column per target group: True where the row's sentence_labels contain that group\n",
        "for label in target_labels:\n",
        "    group_training_df[label] = group_training_df['sentence_labels'].map(\n",
        "        lambda row_labels: label in row_labels\n",
        "    )"
      ],
      "metadata": {
        "id": "uqiYgGGGSTBh"
      },
      "execution_count": 22,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Strip inline [G_...] annotation markers from a sentence\n",
        "def remove_tags(text):\n",
        "    \"\"\"Return text without [G_...] tags and without leading/trailing whitespace.\n",
        "\n",
        "    Values that are not strings are passed through unchanged.\n",
        "    \"\"\"\n",
        "    if not isinstance(text, str):\n",
        "        return text\n",
        "    return re.sub(r'\\[G_[^\\]]+\\]', '', text).strip()\n",
        "\n",
        "# Tag-free copy of every sentence, stored next to the annotated original\n",
        "group_training_df['clean_sentence'] = group_training_df['sentence'].map(remove_tags)"
      ],
      "metadata": {
        "id": "I4mlbsizSVZ_"
      },
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1.2 Train and test split"
      ],
      "metadata": {
        "id": "2TBeoHkkSX69"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Rows carrying each target label in the whole dataset, most frequent first\n",
        "print(\n",
        "    \"Label distribution in the full dataset:\",\n",
        "    group_training_df[target_labels].sum().sort_values(ascending=False),\n",
        "    sep=\"\\n\",\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uI8U8dbzSXRa",
        "outputId": "b978de2d-e424-4fbc-c67b-fef5ee30fc05"
      },
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Label distribution in the full dataset:\n",
            "Vermieter/Eigentümer    379\n",
            "Mieter                  225\n",
            "Bedürftig               203\n",
            "Wohlhabend               60\n",
            "dtype: int64\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# Inputs are the tag-free sentences; targets are the boolean target-label columns\n",
        "X, y = group_training_df['clean_sentence'], group_training_df[target_labels]\n",
        "\n",
        "# 80/20 random hold-out with a fixed seed; the split is not stratified by label\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.2)\n"
      ],
      "metadata": {
        "id": "CTjiA3kCSdCQ"
      },
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Positive rows per label in each split, most frequent first\n",
        "for heading, split_labels in (\n",
        "    (\"\\nLabel distribution in the training set:\", y_train),\n",
        "    (\"\\nLabel distribution in the test set:\", y_test),\n",
        "):\n",
        "    print(heading)\n",
        "    print(split_labels.sum().sort_values(ascending=False))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2DDY2C0VSfHJ",
        "outputId": "97e9dd49-78a2-4ad1-8765-6db0fbeae96a"
      },
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Label distribution in the training set:\n",
            "Vermieter/Eigentümer    303\n",
            "Mieter                  178\n",
            "Bedürftig               156\n",
            "Wohlhabend               50\n",
            "dtype: int64\n",
            "\n",
            "Label distribution in the test set:\n",
            "Vermieter/Eigentümer    76\n",
            "Bedürftig               47\n",
            "Mieter                  47\n",
            "Wohlhabend              10\n",
            "dtype: int64\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Share of rows carrying each label, in percent, for each split\n",
        "for heading, split_labels in (\n",
        "    (\"\\nLabel percentage in training set:\", y_train),\n",
        "    (\"\\nLabel percentage in test set:\", y_test),\n",
        "):\n",
        "    print(heading)\n",
        "    print((split_labels.sum() / len(split_labels) * 100).round(2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hl2flsQjSkmK",
        "outputId": "842ae308-206b-4b9e-b614-fdc70d41d961"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Label percentage in training set:\n",
            "Bedürftig               2.29\n",
            "Wohlhabend              0.74\n",
            "Mieter                  2.62\n",
            "Vermieter/Eigentümer    4.46\n",
            "dtype: float64\n",
            "\n",
            "Label percentage in test set:\n",
            "Bedürftig               2.76\n",
            "Wohlhabend              0.59\n",
            "Mieter                  2.76\n",
            "Vermieter/Eigentümer    4.47\n",
            "dtype: float64\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1.3 Vectorization"
      ],
      "metadata": {
        "id": "A-oLLdZ8Smwz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.feature_extraction.text import TfidfVectorizer\n",
        "\n",
        "# TF-IDF over unigrams and bigrams, capped at 5000 features\n",
        "vectorizer = TfidfVectorizer(ngram_range=(1, 2), max_features=5000)\n",
        "\n",
        "# Fit on the training sentences only, then reuse the fitted vectorizer for the test set\n",
        "X_train_vec = vectorizer.fit_transform(X_train)\n",
        "X_test_vec = vectorizer.transform(X_test)"
      ],
      "metadata": {
        "id": "9Gz3uk6ySmmv"
      },
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1.4 Logistic regression as a benchmark"
      ],
      "metadata": {
        "id": "S9TYTODh6dnH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.multiclass import OneVsRestClassifier\n",
        "\n",
        "# Baseline: one class-weighted logistic regression per target label (one-vs-rest)\n",
        "clf = OneVsRestClassifier(\n",
        "    estimator=LogisticRegression(solver='liblinear', max_iter=1000, class_weight='balanced')\n",
        ")\n",
        "\n",
        "# Fit on the TF-IDF training matrix; the fitted estimator is the cell's displayed value\n",
        "clf.fit(X_train_vec, y_train)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 167
        },
        "id": "piqkzCFuSriu",
        "outputId": "88b3fd82-5fd9-459f-cbd0-6ec1e4c26e22"
      },
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "OneVsRestClassifier(estimator=LogisticRegression(class_weight='balanced',\n",
              "                                                 max_iter=1000,\n",
              "                                                 solver='liblinear'))"
            ],
            "text/html": [
              "<style>#sk-container-id-1 {\n",
              "  /* Definition of color scheme common for light and dark mode */\n",
              "  --sklearn-color-text: #000;\n",
              "  --sklearn-color-text-muted: #666;\n",
              "  --sklearn-color-line: gray;\n",
              "  /* Definition of color scheme for unfitted estimators */\n",
              "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
              "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
              "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
              "  --sklearn-color-unfitted-level-3: chocolate;\n",
              "  /* Definition of color scheme for fitted estimators */\n",
              "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
              "  --sklearn-color-fitted-level-1: #d4ebff;\n",
              "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
              "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
              "\n",
              "  /* Specific color for light theme */\n",
              "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
              "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
              "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
              "  --sklearn-color-icon: #696969;\n",
              "\n",
              "  @media (prefers-color-scheme: dark) {\n",
              "    /* Redefinition of color scheme for dark theme */\n",
              "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
              "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
              "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
              "    --sklearn-color-icon: #878787;\n",
              "  }\n",
              "}\n",
              "\n",
              "#sk-container-id-1 {\n",
              "  color: var(--sklearn-color-text);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 pre {\n",
              "  padding: 0;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 input.sk-hidden--visually {\n",
              "  border: 0;\n",
              "  clip: rect(1px 1px 1px 1px);\n",
              "  clip: rect(1px, 1px, 1px, 1px);\n",
              "  height: 1px;\n",
              "  margin: -1px;\n",
              "  overflow: hidden;\n",
              "  padding: 0;\n",
              "  position: absolute;\n",
              "  width: 1px;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-dashed-wrapped {\n",
              "  border: 1px dashed var(--sklearn-color-line);\n",
              "  margin: 0 0.4em 0.5em 0.4em;\n",
              "  box-sizing: border-box;\n",
              "  padding-bottom: 0.4em;\n",
              "  background-color: var(--sklearn-color-background);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-container {\n",
              "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
              "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
              "     so we also need the `!important` here to be able to override the\n",
              "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
              "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
              "  display: inline-block !important;\n",
              "  position: relative;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-text-repr-fallback {\n",
              "  display: none;\n",
              "}\n",
              "\n",
              "div.sk-parallel-item,\n",
              "div.sk-serial,\n",
              "div.sk-item {\n",
              "  /* draw centered vertical line to link estimators */\n",
              "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
              "  background-size: 2px 100%;\n",
              "  background-repeat: no-repeat;\n",
              "  background-position: center center;\n",
              "}\n",
              "\n",
              "/* Parallel-specific style estimator block */\n",
              "\n",
              "#sk-container-id-1 div.sk-parallel-item::after {\n",
              "  content: \"\";\n",
              "  width: 100%;\n",
              "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
              "  flex-grow: 1;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-parallel {\n",
              "  display: flex;\n",
              "  align-items: stretch;\n",
              "  justify-content: center;\n",
              "  background-color: var(--sklearn-color-background);\n",
              "  position: relative;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-parallel-item {\n",
              "  display: flex;\n",
              "  flex-direction: column;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
              "  align-self: flex-end;\n",
              "  width: 50%;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
              "  align-self: flex-start;\n",
              "  width: 50%;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
              "  width: 0;\n",
              "}\n",
              "\n",
              "/* Serial-specific style estimator block */\n",
              "\n",
              "#sk-container-id-1 div.sk-serial {\n",
              "  display: flex;\n",
              "  flex-direction: column;\n",
              "  align-items: center;\n",
              "  background-color: var(--sklearn-color-background);\n",
              "  padding-right: 1em;\n",
              "  padding-left: 1em;\n",
              "}\n",
              "\n",
              "\n",
              "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
              "clickable and can be expanded/collapsed.\n",
              "- Pipeline and ColumnTransformer use this feature and define the default style\n",
              "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
              "*/\n",
              "\n",
              "/* Pipeline and ColumnTransformer style (default) */\n",
              "\n",
              "#sk-container-id-1 div.sk-toggleable {\n",
              "  /* Default theme specific background. It is overwritten whether we have a\n",
              "  specific estimator or a Pipeline/ColumnTransformer */\n",
              "  background-color: var(--sklearn-color-background);\n",
              "}\n",
              "\n",
              "/* Toggleable label */\n",
              "#sk-container-id-1 label.sk-toggleable__label {\n",
              "  cursor: pointer;\n",
              "  display: flex;\n",
              "  width: 100%;\n",
              "  margin-bottom: 0;\n",
              "  padding: 0.5em;\n",
              "  box-sizing: border-box;\n",
              "  text-align: center;\n",
              "  align-items: start;\n",
              "  justify-content: space-between;\n",
              "  gap: 0.5em;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 label.sk-toggleable__label .caption {\n",
              "  font-size: 0.6rem;\n",
              "  font-weight: lighter;\n",
              "  color: var(--sklearn-color-text-muted);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
              "  /* Arrow on the left of the label */\n",
              "  content: \"▸\";\n",
              "  float: left;\n",
              "  margin-right: 0.25em;\n",
              "  color: var(--sklearn-color-icon);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
              "  color: var(--sklearn-color-text);\n",
              "}\n",
              "\n",
              "/* Toggleable content - dropdown */\n",
              "\n",
              "#sk-container-id-1 div.sk-toggleable__content {\n",
              "  max-height: 0;\n",
              "  max-width: 0;\n",
              "  overflow: hidden;\n",
              "  text-align: left;\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-unfitted-level-0);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
              "  /* fitted */\n",
              "  background-color: var(--sklearn-color-fitted-level-0);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-toggleable__content pre {\n",
              "  margin: 0.2em;\n",
              "  border-radius: 0.25em;\n",
              "  color: var(--sklearn-color-text);\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-unfitted-level-0);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-fitted-level-0);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
              "  /* Expand drop-down */\n",
              "  max-height: 200px;\n",
              "  max-width: 100%;\n",
              "  overflow: auto;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
              "  content: \"▾\";\n",
              "}\n",
              "\n",
              "/* Pipeline/ColumnTransformer-specific style */\n",
              "\n",
              "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
              "  color: var(--sklearn-color-text);\n",
              "  background-color: var(--sklearn-color-unfitted-level-2);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
              "  background-color: var(--sklearn-color-fitted-level-2);\n",
              "}\n",
              "\n",
              "/* Estimator-specific style */\n",
              "\n",
              "/* Colorize estimator box */\n",
              "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-unfitted-level-2);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
              "  /* fitted */\n",
              "  background-color: var(--sklearn-color-fitted-level-2);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
              "#sk-container-id-1 div.sk-label label {\n",
              "  /* The background is the default theme color */\n",
              "  color: var(--sklearn-color-text-on-default-background);\n",
              "}\n",
              "\n",
              "/* On hover, darken the color of the background */\n",
              "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
              "  color: var(--sklearn-color-text);\n",
              "  background-color: var(--sklearn-color-unfitted-level-2);\n",
              "}\n",
              "\n",
              "/* Label box, darken color on hover, fitted */\n",
              "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
              "  color: var(--sklearn-color-text);\n",
              "  background-color: var(--sklearn-color-fitted-level-2);\n",
              "}\n",
              "\n",
              "/* Estimator label */\n",
              "\n",
              "#sk-container-id-1 div.sk-label label {\n",
              "  font-family: monospace;\n",
              "  font-weight: bold;\n",
              "  display: inline-block;\n",
              "  line-height: 1.2em;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-label-container {\n",
              "  text-align: center;\n",
              "}\n",
              "\n",
              "/* Estimator-specific */\n",
              "#sk-container-id-1 div.sk-estimator {\n",
              "  font-family: monospace;\n",
              "  border: 1px dotted var(--sklearn-color-border-box);\n",
              "  border-radius: 0.25em;\n",
              "  box-sizing: border-box;\n",
              "  margin-bottom: 0.5em;\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-unfitted-level-0);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-estimator.fitted {\n",
              "  /* fitted */\n",
              "  background-color: var(--sklearn-color-fitted-level-0);\n",
              "}\n",
              "\n",
              "/* on hover */\n",
              "#sk-container-id-1 div.sk-estimator:hover {\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-unfitted-level-2);\n",
              "}\n",
              "\n",
              "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
              "  /* fitted */\n",
              "  background-color: var(--sklearn-color-fitted-level-2);\n",
              "}\n",
              "\n",
              "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
              "\n",
              "/* Common style for \"i\" and \"?\" */\n",
              "\n",
              ".sk-estimator-doc-link,\n",
              "a:link.sk-estimator-doc-link,\n",
              "a:visited.sk-estimator-doc-link {\n",
              "  float: right;\n",
              "  font-size: smaller;\n",
              "  line-height: 1em;\n",
              "  font-family: monospace;\n",
              "  background-color: var(--sklearn-color-background);\n",
              "  border-radius: 1em;\n",
              "  height: 1em;\n",
              "  width: 1em;\n",
              "  text-decoration: none !important;\n",
              "  margin-left: 0.5em;\n",
              "  text-align: center;\n",
              "  /* unfitted */\n",
              "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
              "  color: var(--sklearn-color-unfitted-level-1);\n",
              "}\n",
              "\n",
              ".sk-estimator-doc-link.fitted,\n",
              "a:link.sk-estimator-doc-link.fitted,\n",
              "a:visited.sk-estimator-doc-link.fitted {\n",
              "  /* fitted */\n",
              "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
              "  color: var(--sklearn-color-fitted-level-1);\n",
              "}\n",
              "\n",
              "/* On hover */\n",
              "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
              ".sk-estimator-doc-link:hover,\n",
              "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
              ".sk-estimator-doc-link:hover {\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-unfitted-level-3);\n",
              "  color: var(--sklearn-color-background);\n",
              "  text-decoration: none;\n",
              "}\n",
              "\n",
              "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
              ".sk-estimator-doc-link.fitted:hover,\n",
              "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
              ".sk-estimator-doc-link.fitted:hover {\n",
              "  /* fitted */\n",
              "  background-color: var(--sklearn-color-fitted-level-3);\n",
              "  color: var(--sklearn-color-background);\n",
              "  text-decoration: none;\n",
              "}\n",
              "\n",
              "/* Span, style for the box shown on hovering the info icon */\n",
              ".sk-estimator-doc-link span {\n",
              "  display: none;\n",
              "  z-index: 9999;\n",
              "  position: relative;\n",
              "  font-weight: normal;\n",
              "  right: .2ex;\n",
              "  padding: .5ex;\n",
              "  margin: .5ex;\n",
              "  width: min-content;\n",
              "  min-width: 20ex;\n",
              "  max-width: 50ex;\n",
              "  color: var(--sklearn-color-text);\n",
              "  box-shadow: 2pt 2pt 4pt #999;\n",
              "  /* unfitted */\n",
              "  background: var(--sklearn-color-unfitted-level-0);\n",
              "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
              "}\n",
              "\n",
              ".sk-estimator-doc-link.fitted span {\n",
              "  /* fitted */\n",
              "  background: var(--sklearn-color-fitted-level-0);\n",
              "  border: var(--sklearn-color-fitted-level-3);\n",
              "}\n",
              "\n",
              ".sk-estimator-doc-link:hover span {\n",
              "  display: block;\n",
              "}\n",
              "\n",
              "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
              "\n",
              "#sk-container-id-1 a.estimator_doc_link {\n",
              "  float: right;\n",
              "  font-size: 1rem;\n",
              "  line-height: 1em;\n",
              "  font-family: monospace;\n",
              "  background-color: var(--sklearn-color-background);\n",
              "  border-radius: 1rem;\n",
              "  height: 1rem;\n",
              "  width: 1rem;\n",
              "  text-decoration: none;\n",
              "  /* unfitted */\n",
              "  color: var(--sklearn-color-unfitted-level-1);\n",
              "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
              "  /* fitted */\n",
              "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
              "  color: var(--sklearn-color-fitted-level-1);\n",
              "}\n",
              "\n",
              "/* On hover */\n",
              "#sk-container-id-1 a.estimator_doc_link:hover {\n",
              "  /* unfitted */\n",
              "  background-color: var(--sklearn-color-unfitted-level-3);\n",
              "  color: var(--sklearn-color-background);\n",
              "  text-decoration: none;\n",
              "}\n",
              "\n",
              "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
              "  /* fitted */\n",
              "  background-color: var(--sklearn-color-fitted-level-3);\n",
              "}\n",
              "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>OneVsRestClassifier(estimator=LogisticRegression(class_weight=&#x27;balanced&#x27;,\n",
              "                                                 max_iter=1000,\n",
              "                                                 solver=&#x27;liblinear&#x27;))</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>OneVsRestClassifier</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.multiclass.OneVsRestClassifier.html\">?<span>Documentation for OneVsRestClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>OneVsRestClassifier(estimator=LogisticRegression(class_weight=&#x27;balanced&#x27;,\n",
              "                                                 max_iter=1000,\n",
              "                                                 solver=&#x27;liblinear&#x27;))</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>estimator: LogisticRegression</div></div></label><div class=\"sk-toggleable__content fitted\"><pre>LogisticRegression(class_weight=&#x27;balanced&#x27;, max_iter=1000, solver=&#x27;liblinear&#x27;)</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>LogisticRegression</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.linear_model.LogisticRegression.html\">?<span>Documentation for LogisticRegression</span></a></div></label><div class=\"sk-toggleable__content fitted\"><pre>LogisticRegression(class_weight=&#x27;balanced&#x27;, max_iter=1000, solver=&#x27;liblinear&#x27;)</pre></div> </div></div></div></div></div></div></div></div></div>"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.metrics import classification_report, f1_score, precision_score, recall_score\n",
        "\n",
        "y_pred = clf.predict(X_test_vec)\n",
        "\n",
        "# Per-class precision/recall/F1\n",
        "print(classification_report(y_test, y_pred, target_names=target_labels))\n",
        "\n",
        "# Macro F1 (averages per class)\n",
        "macro_f1 = f1_score(y_test, y_pred, average='macro')\n",
        "print(f\"Macro F1 score: {macro_f1:.4f}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cGTk9GS8StYD",
        "outputId": "ec921d01-5eea-40ca-8f93-83bcc5a486c5"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "                      precision    recall  f1-score   support\n",
            "\n",
            "           Bedürftig       0.40      0.57      0.47        47\n",
            "          Wohlhabend       0.60      0.30      0.40        10\n",
            "              Mieter       0.94      0.96      0.95        47\n",
            "Vermieter/Eigentümer       0.74      0.76      0.75        76\n",
            "\n",
            "           micro avg       0.67      0.74      0.70       180\n",
            "           macro avg       0.67      0.65      0.64       180\n",
            "        weighted avg       0.70      0.74      0.71       180\n",
            "         samples avg       0.06      0.07      0.06       180\n",
            "\n",
            "Macro F1 score: 0.6425\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
            "/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
            "/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1.5 Training of bert-base_german_cased"
      ],
      "metadata": {
        "id": "vw1-mHyjSv2-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pip install transformers datasets scikit-learn"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZO_T6ijTSxzL",
        "outputId": "705e6f60-f0aa-4aef-b975-2e46538b30f4"
      },
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.0.0)\n",
            "Requirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.0.0)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.12/dist-packages (1.6.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers) (3.20.3)\n",
            "Requirement already satisfied: huggingface-hub<2.0,>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (1.4.0)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (26.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.3)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2025.11.3)\n",
            "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)\n",
            "Requirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from transformers) (0.21.1)\n",
            "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.7.0)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers) (4.67.3)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (18.1.0)\n",
            "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.3.8)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.6.0)\n",
            "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\n",
            "Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2025.3.0)\n",
            "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.16.3)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.5.3)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (3.6.0)\n",
            "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (3.13.3)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (1.2.0)\n",
            "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (0.28.1)\n",
            "Requirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (1.5.4)\n",
            "Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (4.15.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2026.1.4)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.3)\n",
            "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.12/dist-packages (from typer-slim->transformers) (8.3.1)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2.6.1)\n",
            "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.4.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (25.4.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.8.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (6.7.1)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (0.4.1)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.22.0)\n",
            "Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.3.0->transformers) (4.12.1)\n",
            "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.3.0->transformers) (1.0.9)\n",
            "Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.3.0->transformers) (0.16.0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from sklearn.model_selection import train_test_split\n",
        "from transformers import BertTokenizer\n",
        "\n",
        "# HuggingFace model + tokenizer\n",
        "model_name = \"bert-base-german-cased\"\n",
        "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# Inputs and labels\n",
        "X = group_training_df['clean_sentence'].tolist()\n",
        "y = group_training_df[target_labels].values.astype(float)\n",
        "\n",
        "# Split\n",
        "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n",
        "\n",
        "# Tokenize\n",
        "train_encodings = tokenizer(X_train, truncation=True, padding=True, max_length=128)\n",
        "val_encodings = tokenizer(X_val, truncation=True, padding=True, max_length=128)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 252,
          "referenced_widgets": [
            "e9038992d0a94189aea25e7aafd0d955",
            "df5abe1d8cc54d809a7ca71c2b5cd72a",
            "c3e956c9c95d4ef09c46c1d1d2aee6ae",
            "0b814c5163ef48f7bd726ae215ba3ccb",
            "6a5dd5f8b5ea4dac921ca3e5c0d6c7fe",
            "e481024b210446c9bfe6dd01c222badc",
            "adcb43e6af794301af2d293dee02cfae",
            "e49e945c84da44e1a9e83c1b062895f2",
            "c66fb5047a714baea3b60da040d4885d",
            "4f2c2f387bd34b7a946a924d28d06512",
            "9e8e9024a0e84af78fc232994ac9603b",
            "c58dcf3b09564eefb270ca54d38b1ce3",
            "69c264f9c97f49b789e7402f4425ca82",
            "84bd07ee3e6e4c969d5953d8230e2191",
            "cf19a9a733b141ed9d7647145c17bf82",
            "221af6eaafb341f3b629b6058b6725d3",
            "7b50eeeb182641779220739cb1a3911c",
            "77bfee349ff14975a3aee2fbc0999399",
            "c931f85d242b43ef9e1a0114e8e4b507",
            "236e7cc9c317461ca072d3b7292861de",
            "4837a31c4b8d409fa4e454b1b5ab32de",
            "8b55796b7a474123aadd27d5175e245f",
            "53a3f5de9ec3420d99649bd080580bcc",
            "8a0cbccd477e4d729f6a231492480a1b",
            "7287055921464efda101a065562fda70",
            "22f8bf8d9d1c4dfda3c5095dd80fd5d6",
            "4ee58848744f4ffc99799d85a7290786",
            "529b644ec0e245cb80b88e3f0401f590",
            "72bf61c3726447229cc207db31a2d3cf",
            "b8ebcb0c9f2844c6aa6995da80343167",
            "6e44050baff942b9a72d61d8ed3a581d",
            "7997aebdc5a244ceb69bba6f91d2a53d",
            "d17f4c4270944659b369bcbea60a1ccb"
          ]
        },
        "id": "aUN16KIKSz4t",
        "outputId": "08c79b30-1011-4dae-a9ae-17907e534bd7"
      },
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e9038992d0a94189aea25e7aafd0d955"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "vocab.txt:   0%|          | 0.00/255k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "c58dcf3b09564eefb270ca54d38b1ce3"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/485k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "53a3f5de9ec3420d99649bd080580bcc"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n",
            "WARNING:huggingface_hub.utils._http:Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Note:** The follwoing training pipeline will take time in particualr if run without GPUs"
      ],
      "metadata": {
        "id": "PmtyE7BD2ZfX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "from transformers import BertTokenizerFast, BertModel, Trainer, TrainingArguments\n",
        "from sklearn.metrics import precision_score, recall_score, f1_score, classification_report\n",
        "\n",
        "# Make sure GPU is available\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"Using device: {device}\")\n",
        "\n",
        "# Load tokenizer and base model\n",
        "tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-german-cased\")\n",
        "\n",
        "# Define your dataset (continuing from where you stopped)\n",
        "class MultiLabelDataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, encodings, labels):\n",
        "        self.encodings = encodings\n",
        "        self.labels = torch.tensor(labels, dtype=torch.float32)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return {\n",
        "            key: torch.tensor(val[idx]) for key, val in self.encodings.items()\n",
        "        } | {'labels': self.labels[idx]}\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.labels)\n",
        "\n",
        "train_dataset = MultiLabelDataset(train_encodings, y_train)\n",
        "val_dataset = MultiLabelDataset(val_encodings, y_val)\n",
        "\n",
        "# Compute positive class weights (for imbalance handling)\n",
        "label_counts = y_train.sum(axis=0)\n",
        "label_freqs = label_counts / y_train.shape[0]\n",
        "pos_weights = 1.0 / (label_freqs + 1e-6)  # Add epsilon to avoid div-by-zero\n",
        "pos_weights = torch.tensor(pos_weights, dtype=torch.float32).to(device)\n",
        "\n",
        "# Define model with BERT + classification head\n",
        "class BertForMultiLabelClassification(nn.Module):\n",
        "    def __init__(self, num_labels, pos_weights):\n",
        "        super().__init__()\n",
        "        self.bert = BertModel.from_pretrained(\"bert-base-german-cased\")\n",
        "        self.dropout = nn.Dropout(0.3)\n",
        "        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)\n",
        "        self.pos_weights = pos_weights\n",
        "\n",
        "    def forward(self, input_ids=None, attention_mask=None, labels=None):\n",
        "        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
        "        pooled_output = self.dropout(outputs.pooler_output)\n",
        "        logits = self.classifier(pooled_output)\n",
        "\n",
        "        loss = None\n",
        "        if labels is not None:\n",
        "            loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weights)\n",
        "            loss = loss_fct(logits, labels)\n",
        "\n",
        "        return {\"loss\": loss, \"logits\": logits}\n",
        "\n",
        "# Instantiate model\n",
        "num_labels = y_train.shape[1]\n",
        "model = BertForMultiLabelClassification(num_labels=num_labels, pos_weights=pos_weights)\n",
        "model.to(device)\n",
        "\n",
        "# Custom compute_metrics for macro F1\n",
        "\n",
        "def compute_metrics(pred):\n",
        "    logits, labels = pred\n",
        "    preds = (logits > 0).astype(int)\n",
        "\n",
        "    return {\n",
        "        'precision_macro': precision_score(labels, preds, average='macro', zero_division=0),\n",
        "        'recall_macro': recall_score(labels, preds, average='macro', zero_division=0),\n",
        "        'f1_macro': f1_score(labels, preds, average='macro', zero_division=0),\n",
        "        'precision_micro': precision_score(labels, preds, average='micro', zero_division=0),\n",
        "        'recall_micro': recall_score(labels, preds, average='micro', zero_division=0),\n",
        "        'f1_micro': f1_score(labels, preds, average='micro', zero_division=0)\n",
        "    }\n",
        "\n",
        "# Define Trainer args\n",
        "training_args = TrainingArguments(\n",
        "    output_dir=\"./results\",\n",
        "    num_train_epochs=10,\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=32,\n",
        "    eval_strategy=\"epoch\",\n",
        "    save_strategy=\"epoch\",\n",
        "    logging_strategy=\"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    weight_decay=0.01,\n",
        "    load_best_model_at_end=True,\n",
        "    metric_for_best_model=\"f1_macro\",\n",
        "    save_total_limit=2,\n",
        "    report_to=\"none\",\n",
        ")\n",
        "\n",
        "# Wrap Trainer\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=val_dataset,\n",
        "    compute_metrics=compute_metrics,\n",
        ")\n",
        "\n",
        "# Train the model\n",
        "trainer.train()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 813,
          "referenced_widgets": [
            "0c1e20e2145d4281b633472ed9457467",
            "b05effd9c89744f081bd286257b6a874",
            "09d388cd36b74b54a6652a647b1f5951",
            "78f0e7a863254ed98245fa498c78c838",
            "52ad3d6aa2594f9491beba5e0ecbd3ad",
            "090b34941f9541d7b50d08fa254774eb",
            "b58a0bc45a21441f9c949e9f71b91c6c",
            "51302eec9fc54c4a8d603c2903ebcbc3",
            "5e892e68b5f84ca8bd9d28b5282674c8",
            "f808a1633ee14b47b2013fff4b9e45b8",
            "2c51ba3cfe13470e826182ea5e15c5b6",
            "891331f6d3fc44838aac504c0d0a09b0",
            "b0c7334da03a4d019697751e92e3b6d2",
            "50e6edc971714e479c5ed4bde7f0e97b",
            "a85a5b3d241945ffaaadabc516c0d262",
            "a839295991cb428ab9a0b01bca4cd188",
            "df4272233fd64460a5ef2cd526835e77",
            "724ee2ed34454891b235f53279abef49",
            "c907e791e4dc495aa1aeab3c8aeafd5f",
            "1d475887cba24d34ad44d02eb64f86bc",
            "269564898d6b4c37acbf0ab9f3b3a07c",
            "75c308037d4c405aaffe888fe5c82c43",
            "134ce7ce58b245ffbb6341120c321f72",
            "106d65461c3e40a08732d506b0b52bcd",
            "3b0645d3b94044c5a25490e89d177326",
            "36f31ddf7b93406d9da76c206fecf812",
            "13d3d1305df442f89acc84af217ebc26",
            "5955b8362ff64eb89aaff3e998c06364",
            "24c80a73020d469aad49b1979ac05b62",
            "8ee757eb7908481f85ff5c536fc8a214",
            "97dc681f605d48a58ec498f6042c3aab",
            "61474933d4564ae4ba82b7b7c9f3bc76",
            "f423bb07cbc9412ba7af675097173edc"
          ]
        },
        "id": "hb7vxfaKS2gK",
        "outputId": "c15947b9-e936-4ac5-f952-62da631849bc"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using device: cuda\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "config.json:   0%|          | 0.00/433 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "0c1e20e2145d4281b633472ed9457467"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/439M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "891331f6d3fc44838aac504c0d0a09b0"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "134ce7ce58b245ffbb6341120c321f72"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "BertModel LOAD REPORT from: bert-base-german-cased\n",
            "Key                                        | Status     |  | \n",
            "-------------------------------------------+------------+--+-\n",
            "cls.seq_relationship.bias                  | UNEXPECTED |  | \n",
            "cls.seq_relationship.weight                | UNEXPECTED |  | \n",
            "cls.predictions.bias                       | UNEXPECTED |  | \n",
            "cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | \n",
            "cls.predictions.transform.dense.weight     | UNEXPECTED |  | \n",
            "cls.predictions.transform.dense.bias       | UNEXPECTED |  | \n",
            "cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | \n",
            "\n",
            "Notes:\n",
            "- UNEXPECTED\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1647' max='4250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1647/4250 02:27 < 03:53, 11.16 it/s, Epoch 3.87/10]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Precision Macro</th>\n",
              "      <th>Recall Macro</th>\n",
              "      <th>F1 Macro</th>\n",
              "      <th>Precision Micro</th>\n",
              "      <th>Recall Micro</th>\n",
              "      <th>F1 Micro</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.676197</td>\n",
              "      <td>0.317485</td>\n",
              "      <td>0.666141</td>\n",
              "      <td>0.886450</td>\n",
              "      <td>0.736711</td>\n",
              "      <td>0.724444</td>\n",
              "      <td>0.905556</td>\n",
              "      <td>0.804938</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.287085</td>\n",
              "      <td>0.356448</td>\n",
              "      <td>0.824623</td>\n",
              "      <td>0.877898</td>\n",
              "      <td>0.848126</td>\n",
              "      <td>0.826733</td>\n",
              "      <td>0.927778</td>\n",
              "      <td>0.874346</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.174397</td>\n",
              "      <td>0.462729</td>\n",
              "      <td>0.855620</td>\n",
              "      <td>0.873782</td>\n",
              "      <td>0.864125</td>\n",
              "      <td>0.894444</td>\n",
              "      <td>0.894444</td>\n",
              "      <td>0.894444</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-2577662762.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m    103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m \u001b[0;31m# Train the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2172\u001b[0m                 \u001b[0mhf_hub_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_progress_bars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2173\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2174\u001b[0;31m             return inner_training_loop(\n\u001b[0m\u001b[1;32m   2175\u001b[0m                 \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2176\u001b[0m                 \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2534\u001b[0m                     )\n\u001b[1;32m   2535\u001b[0m                     \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2536\u001b[0;31m                         \u001b[0mtr_loss_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2537\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2538\u001b[0m                     if (\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, model, inputs, num_items_in_batch)\u001b[0m\n\u001b[1;32m   3806\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mloss_mb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3807\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3808\u001b[0;31m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss_context_manager\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3809\u001b[0m                 \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_items_in_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3810\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mcompute_loss_context_manager\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   3753\u001b[0m         \u001b[0mA\u001b[0m \u001b[0mhelper\u001b[0m \u001b[0mwrapper\u001b[0m \u001b[0mto\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0mtogether\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0mmanagers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3754\u001b[0m         \"\"\"\n\u001b[0;32m-> 3755\u001b[0;31m         \u001b[0mctx_stack\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcontextlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExitStack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3756\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3757\u001b[0m         \u001b[0mautocast_ctx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast_smart_context_manager\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.12/contextlib.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    479\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0m_exit_wrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    480\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 481\u001b[0;31m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    482\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exit_callbacks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeque\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    483\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here we coul save the model (disabled)"
      ],
      "metadata": {
        "id": "UT4eu5XLS8Xx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# save_path = \"/content/drive/models/bert_multilabel_group_model\"\n",
        "# model.bert.save_pretrained(save_path)\n",
        "# tokenizer.save_pretrained(save_path)"
      ],
      "metadata": {
        "id": "iluSCTg4S8sE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# trainer.save_model(save_path)"
      ],
      "metadata": {
        "id": "w_LV2deeTAZ9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# import os\n",
        "# os.listdir(save_path)"
      ],
      "metadata": {
        "id": "njJAIBZDTCD2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 2.Second round training with revised training data"
      ],
      "metadata": {
        "id": "4FNIoTdOpHgi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "group_training_df = pd.read_excel('/content/drive/all_sentences_df_clean_kw_checked_added.xlsx')"
      ],
      "metadata": {
        "id": "DEm-xr7VpIGZ"
      },
      "execution_count": 51,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2.1 Extracting the labels"
      ],
      "metadata": {
        "id": "bW_X4RG47Rvy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "import re\n",
        "\n",
        "\n",
        "# Function to extract all annotations like [G_l_m_neu]\n",
        "def extract_tags(text):\n",
        "    if isinstance(text, str): # Check if the text is a string\n",
        "        return re.findall(r'\\[G_([^\\]]+)\\]', text)\n",
        "    else:\n",
        "        return [] # Return an empty list for non-string values\n",
        "\n",
        "# Apply to create new column with list of annotation components\n",
        "group_training_df['tags'] = group_training_df['sentence'].apply(extract_tags)"
      ],
      "metadata": {
        "id": "CRsTMHWhpLwQ"
      },
      "execution_count": 35,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Codebook mappings\n",
        "dimension_codes = {\n",
        "    'b': 'Ökonomisch - Bedürftig',\n",
        "    'w': 'Ökonomisch - Wohlhabend',\n",
        "    'm': 'Eigentümer - Mieter',\n",
        "    'v': 'Eigentümer - Vermieter/Eigentümer',\n",
        "    'e': 'Wohnstruktur - Einparteienhaus',\n",
        "    'mp': 'Wohnstruktur - Mehrparteienhaus',\n",
        "    'l': 'Geographie - Land',\n",
        "    's': 'Geographie - Stadt',\n",
        "    'o': 'Other - All other'\n",
        "}\n",
        "\n",
        "appeal_codes = {\n",
        "    'pos': 'positiv',\n",
        "    'neu': 'neutral',\n",
        "    'neg': 'negativ'\n",
        "}\n",
        "\n",
        "# Function to classify each tag\n",
        "def classify_tag_components(tag_str):\n",
        "    parts = tag_str.split('_')\n",
        "    result = {'dimensions': [], 'appeal': None}\n",
        "    for part in parts:\n",
        "        if part in appeal_codes:\n",
        "            result['appeal'] = appeal_codes[part]\n",
        "        elif part in dimension_codes:\n",
        "            result['dimensions'].append(dimension_codes[part])\n",
        "        else:\n",
        "            # For compound codes like 'mp' (Mehrparteienhaus)\n",
        "            compound_matches = [code for code in dimension_codes if code in part]\n",
        "            for match in compound_matches:\n",
        "                if match == part:\n",
        "                    result['dimensions'].append(dimension_codes[match])\n",
        "    return result\n",
        "\n",
        "# Expand all tag components into dimensions and appeal per row\n",
        "def analyze_tags(tag_list):\n",
        "    results = []\n",
        "    for tag in tag_list:\n",
        "        results.append(classify_tag_components(tag))\n",
        "    return results\n",
        "\n",
        "group_training_df['tag_details'] = group_training_df['tags'].apply(analyze_tags)"
      ],
      "metadata": {
        "id": "yqEiau3JpPry"
      },
      "execution_count": 36,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from collections import Counter\n",
        "\n",
        "all_dimensions = group_training_df['tag_details'].explode().dropna().apply(lambda x: x['dimensions'])\n",
        "flat_dimensions = [dim for sublist in all_dimensions.dropna() for dim in sublist]\n",
        "dimension_counts = Counter(flat_dimensions)\n",
        "print(dimension_counts)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kchdazw0pR5P",
        "outputId": "9189d2a8-5f08-427f-fab7-ec26b6e424dc"
      },
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Counter({'Other - All other': 790, 'Eigentümer - Vermieter/Eigentümer': 422, 'Ökonomisch - Bedürftig': 283, 'Eigentümer - Mieter': 248, 'Ökonomisch - Wohlhabend': 89, 'Wohnstruktur - Mehrparteienhaus': 9, 'Geographie - Land': 8, 'Wohnstruktur - Einparteienhaus': 7, 'Geographie - Stadt': 7})\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "all_appeals = group_training_df['tag_details'].explode().dropna().apply(lambda x: x['appeal'])\n",
        "appeal_counts = Counter(all_appeals)\n",
        "print(appeal_counts)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rBX_4L_ypTZe",
        "outputId": "46d41657-6365-488e-ad98-c73a484a5680"
      },
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Counter({'positiv': 1046, 'neutral': 586, 'negativ': 131, None: 25})\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Create exploded version for analysis\n",
        "rows = []\n",
        "for idx, tag_list in group_training_df['tag_details'].items():\n",
        "    for tag in tag_list:\n",
        "        for dim in tag['dimensions']:\n",
        "            rows.append({\n",
        "                'article_id': idx,\n",
        "                'dimension': dim,\n",
        "                'appeal': tag['appeal']\n",
        "            })\n",
        "\n",
        "analysis_df = pd.DataFrame(rows)"
      ],
      "metadata": {
        "id": "TrHf_fInpU-V"
      },
      "execution_count": 39,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Frequency table\n",
        "pd.crosstab(analysis_df['dimension'], analysis_df['appeal'])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 363
        },
        "id": "O9LnuP3VpWxx",
        "outputId": "423c2dfe-a99d-4113-b78a-3b904c7dd0f9"
      },
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "appeal                             negativ  neutral  positiv\n",
              "dimension                                                   \n",
              "Eigentümer - Mieter                      2       67      179\n",
              "Eigentümer - Vermieter/Eigentümer       20      186      216\n",
              "Geographie - Land                        0        1        7\n",
              "Geographie - Stadt                       0        5        2\n",
              "Other - All other                       32      278      457\n",
              "Wohnstruktur - Einparteienhaus           1        5        1\n",
              "Wohnstruktur - Mehrparteienhaus          0        5        4\n",
              "Ökonomisch - Bedürftig                  11       41      230\n",
              "Ökonomisch - Wohlhabend                 67       18        4"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-021c3525-91ed-45ed-8d96-ba900a099b0c\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th>appeal</th>\n",
              "      <th>negativ</th>\n",
              "      <th>neutral</th>\n",
              "      <th>positiv</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>dimension</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>Eigentümer - Mieter</th>\n",
              "      <td>2</td>\n",
              "      <td>67</td>\n",
              "      <td>179</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Eigentümer - Vermieter/Eigentümer</th>\n",
              "      <td>20</td>\n",
              "      <td>186</td>\n",
              "      <td>216</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Geographie - Land</th>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>7</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Geographie - Stadt</th>\n",
              "      <td>0</td>\n",
              "      <td>5</td>\n",
              "      <td>2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Other - All other</th>\n",
              "      <td>32</td>\n",
              "      <td>278</td>\n",
              "      <td>457</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Wohnstruktur - Einparteienhaus</th>\n",
              "      <td>1</td>\n",
              "      <td>5</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Wohnstruktur - Mehrparteienhaus</th>\n",
              "      <td>0</td>\n",
              "      <td>5</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Ökonomisch - Bedürftig</th>\n",
              "      <td>11</td>\n",
              "      <td>41</td>\n",
              "      <td>230</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Ökonomisch - Wohlhabend</th>\n",
              "      <td>67</td>\n",
              "      <td>18</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-021c3525-91ed-45ed-8d96-ba900a099b0c')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-021c3525-91ed-45ed-8d96-ba900a099b0c button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-021c3525-91ed-45ed-8d96-ba900a099b0c');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "summary": "{\n  \"name\": \"pd\",\n  \"rows\": 9,\n  \"fields\": [\n    {\n      \"column\": \"dimension\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 9,\n        \"samples\": [\n          \"\\u00d6konomisch - Bed\\u00fcrftig\",\n          \"Eigent\\u00fcmer - Vermieter/Eigent\\u00fcmer\",\n          \"Wohnstruktur - Einparteienhaus\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"negativ\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 22,\n        \"min\": 0,\n        \"max\": 67,\n        \"num_unique_values\": 7,\n        \"samples\": [\n          2,\n          20,\n          11\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"neutral\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 98,\n        \"min\": 1,\n        \"max\": 278,\n        \"num_unique_values\": 7,\n        \"samples\": [\n          67,\n          186,\n          41\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"positiv\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 160,\n        \"min\": 1,\n        \"max\": 457,\n        \"num_unique_values\": 8,\n        \"samples\": [\n          216,\n          1,\n          179\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 40
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Define label mapping (including 'none')\n",
        "label_map = {\n",
        "    'b': 'Bedürftig',\n",
        "    'w': 'Wohlhabend',\n",
        "    'm': 'Mieter',\n",
        "    'v': 'Vermieter/Eigentümer'\n",
        "}\n",
        "target_labels = list(label_map.values())"
      ],
      "metadata": {
        "id": "xl8eKzJtpYb0"
      },
      "execution_count": 41,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Function to extract the relevant dimensions (b, w, m, v)\n",
        "def extract_target_dimensions(tag_details):\n",
        "    labels = set()\n",
        "    for detail in tag_details:\n",
        "        for dim in detail['dimensions']:\n",
        "            for code, name in label_map.items():\n",
        "                if dimension_codes[code] == dim:\n",
        "                    labels.add(name)\n",
        "    return list(labels)\n",
        "\n",
        "# Apply to get labels\n",
        "group_training_df['sentence_labels'] = group_training_df['tag_details'].apply(extract_target_dimensions)"
      ],
      "metadata": {
        "id": "-aWdT9aWpaER"
      },
      "execution_count": 42,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create multi-hot encoded label columns\n",
        "for label in target_labels:\n",
        "    group_training_df[label] = group_training_df['sentence_labels'].apply(lambda x: label in x)"
      ],
      "metadata": {
        "id": "2Nz2ffA9pb4k"
      },
      "execution_count": 43,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Function to remove [G_...] annotation tags from text\n",
        "def remove_tags(text):\n",
        "    if isinstance(text, str):\n",
        "        return re.sub(r'\\[G_[^\\]]+\\]', '', text).strip()\n",
        "    return text\n",
        "\n",
        "# Apply to create a clean sentence column\n",
        "group_training_df['clean_sentence'] = group_training_df['sentence'].apply(remove_tags)"
      ],
      "metadata": {
        "id": "bQMqFo6-pdZi"
      },
      "execution_count": 44,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"Label distribution in the full dataset:\")\n",
        "print(group_training_df[target_labels].sum().sort_values(ascending=False))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XB6fYA44pfQp",
        "outputId": "5c9a217b-4943-4c5e-ee19-b5a7ffd8d7ed"
      },
      "execution_count": 45,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Label distribution in the full dataset:\n",
            "Vermieter/Eigentümer    401\n",
            "Bedürftig               261\n",
            "Mieter                  240\n",
            "Wohlhabend               88\n",
            "dtype: int64\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2.2 Train and test split"
      ],
      "metadata": {
        "id": "0RKHuet663Zf"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "X = group_training_df['clean_sentence']\n",
        "y = group_training_df[target_labels]\n",
        "\n",
        "X_train, X_test, y_train, y_test = train_test_split(\n",
        "    X, y, test_size=0.2, random_state=42\n",
        ")"
      ],
      "metadata": {
        "id": "5lGS857npg9P"
      },
      "execution_count": 46,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"\\nLabel distribution in the training set:\")\n",
        "print(y_train.sum().sort_values(ascending=False))\n",
        "\n",
        "print(\"\\nLabel distribution in the test set:\")\n",
        "print(y_test.sum().sort_values(ascending=False))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vULjqNNKpi1j",
        "outputId": "b8ba1bc2-c911-4d3c-e3d2-d16aa49ea832"
      },
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Label distribution in the training set:\n",
            "Vermieter/Eigentümer    317\n",
            "Bedürftig               197\n",
            "Mieter                  180\n",
            "Wohlhabend               67\n",
            "dtype: int64\n",
            "\n",
            "Label distribution in the test set:\n",
            "Vermieter/Eigentümer    84\n",
            "Bedürftig               64\n",
            "Mieter                  60\n",
            "Wohlhabend              21\n",
            "dtype: int64\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "pip install transformers datasets scikit-learn"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VIFRs_jjpkWx",
        "outputId": "157b10a5-e6b5-4875-a5cc-c7fc602c7a43"
      },
      "execution_count": 48,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.0.0)\n",
            "Requirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.0.0)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.12/dist-packages (1.6.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers) (3.20.3)\n",
            "Requirement already satisfied: huggingface-hub<2.0,>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (1.4.0)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (26.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.3)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2025.11.3)\n",
            "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)\n",
            "Requirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from transformers) (0.21.1)\n",
            "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.7.0)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers) (4.67.3)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (18.1.0)\n",
            "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.3.8)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.6.0)\n",
            "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\n",
            "Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2025.3.0)\n",
            "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.16.3)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.5.3)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (3.6.0)\n",
            "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (3.13.3)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (1.2.0)\n",
            "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (0.28.1)\n",
            "Requirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (1.5.4)\n",
            "Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (4.15.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2026.1.4)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.3)\n",
            "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.12/dist-packages (from typer-slim->transformers) (8.3.1)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2.6.1)\n",
            "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.4.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (25.4.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.8.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (6.7.1)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (0.4.1)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.22.0)\n",
            "Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.3.0->transformers) (4.12.1)\n",
            "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.3.0->transformers) (1.0.9)\n",
            "Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.3.0->transformers) (0.16.0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from sklearn.model_selection import train_test_split\n",
        "from transformers import BertTokenizer\n",
        "\n",
        "# HuggingFace model + tokenizer\n",
        "model_name = \"bert-base-german-cased\"\n",
        "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# Inputs and labels\n",
        "X = group_training_df['clean_sentence'].tolist()\n",
        "y = group_training_df[target_labels].values.astype(float)\n",
        "\n",
        "# Split\n",
        "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n",
        "\n",
        "# Tokenize\n",
        "train_encodings = tokenizer(X_train, truncation=True, padding=True, max_length=128)\n",
        "val_encodings = tokenizer(X_val, truncation=True, padding=True, max_length=128)\n"
      ],
      "metadata": {
        "id": "gt_syJzypmx6"
      },
      "execution_count": 49,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2.3 Training the model"
      ],
      "metadata": {
        "id": "HgGmgsYv7YhH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "from transformers import BertTokenizerFast, BertModel, Trainer, TrainingArguments\n",
        "from sklearn.metrics import precision_score, recall_score, f1_score, classification_report\n",
        "\n",
        "# Make sure GPU is available\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"Using device: {device}\")\n",
        "\n",
        "# Load tokenizer and base model\n",
        "tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-german-cased\")\n",
        "\n",
        "# Define your dataset (continuing from where you stopped)\n",
        "class MultiLabelDataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, encodings, labels):\n",
        "        self.encodings = encodings\n",
        "        self.labels = torch.tensor(labels, dtype=torch.float32)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return {\n",
        "            key: torch.tensor(val[idx]) for key, val in self.encodings.items()\n",
        "        } | {'labels': self.labels[idx]}\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.labels)\n",
        "\n",
        "train_dataset = MultiLabelDataset(train_encodings, y_train)\n",
        "val_dataset = MultiLabelDataset(val_encodings, y_val)\n",
        "\n",
        "# Compute positive class weights (for imbalance handling)\n",
        "label_counts = y_train.sum(axis=0)\n",
        "label_freqs = label_counts / y_train.shape[0]\n",
        "pos_weights = 1.0 / (label_freqs + 1e-6)  # Add epsilon to avoid div-by-zero\n",
        "pos_weights = torch.tensor(pos_weights, dtype=torch.float32).to(device)\n",
        "\n",
        "# Define model with BERT + classification head\n",
        "class BertForMultiLabelClassification(nn.Module):\n",
        "    def __init__(self, num_labels, pos_weights):\n",
        "        super().__init__()\n",
        "        self.bert = BertModel.from_pretrained(\"bert-base-german-cased\")\n",
        "        self.dropout = nn.Dropout(0.3)\n",
        "        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)\n",
        "        self.pos_weights = pos_weights\n",
        "\n",
        "    def forward(self, input_ids=None, attention_mask=None, labels=None):\n",
        "        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
        "        pooled_output = self.dropout(outputs.pooler_output)\n",
        "        logits = self.classifier(pooled_output)\n",
        "\n",
        "        loss = None\n",
        "        if labels is not None:\n",
        "            loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weights)\n",
        "            loss = loss_fct(logits, labels)\n",
        "\n",
        "        return {\"loss\": loss, \"logits\": logits}\n",
        "\n",
        "# Instantiate model\n",
        "num_labels = y_train.shape[1]\n",
        "model = BertForMultiLabelClassification(num_labels=num_labels, pos_weights=pos_weights)\n",
        "model.to(device)\n",
        "\n",
        "# Custom compute_metrics for macro F1\n",
        "\n",
        "def compute_metrics(pred):\n",
        "    logits, labels = pred\n",
        "    preds = (logits > 0).astype(int)\n",
        "\n",
        "    return {\n",
        "        'precision_macro': precision_score(labels, preds, average='macro', zero_division=0),\n",
        "        'recall_macro': recall_score(labels, preds, average='macro', zero_division=0),\n",
        "        'f1_macro': f1_score(labels, preds, average='macro', zero_division=0),\n",
        "        'precision_micro': precision_score(labels, preds, average='micro', zero_division=0),\n",
        "        'recall_micro': recall_score(labels, preds, average='micro', zero_division=0),\n",
        "        'f1_micro': f1_score(labels, preds, average='micro', zero_division=0)\n",
        "    }\n",
        "\n",
        "# Define Trainer args\n",
        "training_args = TrainingArguments(\n",
        "    output_dir=\"./results\",\n",
        "    num_train_epochs=10,\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=32,\n",
        "    eval_strategy=\"epoch\",\n",
        "    save_strategy=\"epoch\",\n",
        "    logging_strategy=\"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    weight_decay=0.01,\n",
        "    load_best_model_at_end=True,\n",
        "    metric_for_best_model=\"f1_macro\",\n",
        "    save_total_limit=2,\n",
        "    report_to=\"none\",\n",
        ")\n",
        "\n",
        "# Wrap Trainer\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=val_dataset,\n",
        "    compute_metrics=compute_metrics,\n",
        ")\n",
        "\n",
        "# Train the model\n",
        "trainer.train()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 699,
          "referenced_widgets": [
            "e507ccc6226844d581337d7b1376bc67",
            "08d8f19593c3491f8d4a2675e9712c9a",
            "1f6085c27430495daa40d2a0bd0f03af",
            "c991712fbc484f3ba2f1ed943d301ef6",
            "875d4f5d37fe4a5d883abed1995a78e7",
            "8a1bfb11b19143f3bfc9a2b85f9f96e7",
            "7e57800838e54f87851f292fc95c1565",
            "6a32bc8713af4b4ba506c30fe4e1dd8b",
            "bd653ceffd3c406f97d10fd6b39a8b63",
            "32bcfcb821f344ee8543157a68853029",
            "ea6770024da94054b766fd4aa62c801b"
          ]
        },
        "id": "tX5_oYHEpq8h",
        "outputId": "2794ccf4-5ef7-4f70-b646-29383595cf09"
      },
      "execution_count": 50,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using device: cuda\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e507ccc6226844d581337d7b1376bc67"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "BertModel LOAD REPORT from: bert-base-german-cased\n",
            "Key                                        | Status     |  | \n",
            "-------------------------------------------+------------+--+-\n",
            "cls.seq_relationship.bias                  | UNEXPECTED |  | \n",
            "cls.seq_relationship.weight                | UNEXPECTED |  | \n",
            "cls.predictions.bias                       | UNEXPECTED |  | \n",
            "cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | \n",
            "cls.predictions.transform.dense.weight     | UNEXPECTED |  | \n",
            "cls.predictions.transform.dense.bias       | UNEXPECTED |  | \n",
            "cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | \n",
            "\n",
            "Notes:\n",
            "- UNEXPECTED\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='4310' max='4310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [4310/4310 06:35, Epoch 10/10]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Precision Macro</th>\n",
              "      <th>Recall Macro</th>\n",
              "      <th>F1 Macro</th>\n",
              "      <th>Precision Micro</th>\n",
              "      <th>Recall Micro</th>\n",
              "      <th>F1 Micro</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.551153</td>\n",
              "      <td>0.486345</td>\n",
              "      <td>0.855627</td>\n",
              "      <td>0.846689</td>\n",
              "      <td>0.849065</td>\n",
              "      <td>0.827004</td>\n",
              "      <td>0.855895</td>\n",
              "      <td>0.841202</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.223903</td>\n",
              "      <td>0.495483</td>\n",
              "      <td>0.928521</td>\n",
              "      <td>0.847247</td>\n",
              "      <td>0.881261</td>\n",
              "      <td>0.909091</td>\n",
              "      <td>0.873362</td>\n",
              "      <td>0.890869</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.137706</td>\n",
              "      <td>0.360235</td>\n",
              "      <td>0.837911</td>\n",
              "      <td>0.901935</td>\n",
              "      <td>0.866458</td>\n",
              "      <td>0.866109</td>\n",
              "      <td>0.903930</td>\n",
              "      <td>0.884615</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.101688</td>\n",
              "      <td>0.451579</td>\n",
              "      <td>0.863364</td>\n",
              "      <td>0.875632</td>\n",
              "      <td>0.869025</td>\n",
              "      <td>0.880342</td>\n",
              "      <td>0.899563</td>\n",
              "      <td>0.889849</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.051779</td>\n",
              "      <td>0.390733</td>\n",
              "      <td>0.876374</td>\n",
              "      <td>0.906585</td>\n",
              "      <td>0.890573</td>\n",
              "      <td>0.901288</td>\n",
              "      <td>0.917031</td>\n",
              "      <td>0.909091</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>0.023449</td>\n",
              "      <td>0.523280</td>\n",
              "      <td>0.851495</td>\n",
              "      <td>0.888988</td>\n",
              "      <td>0.869256</td>\n",
              "      <td>0.869748</td>\n",
              "      <td>0.903930</td>\n",
              "      <td>0.886510</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>0.029234</td>\n",
              "      <td>0.644249</td>\n",
              "      <td>0.912942</td>\n",
              "      <td>0.871912</td>\n",
              "      <td>0.891722</td>\n",
              "      <td>0.918552</td>\n",
              "      <td>0.886463</td>\n",
              "      <td>0.902222</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>0.022747</td>\n",
              "      <td>0.590192</td>\n",
              "      <td>0.876405</td>\n",
              "      <td>0.902604</td>\n",
              "      <td>0.888438</td>\n",
              "      <td>0.896104</td>\n",
              "      <td>0.903930</td>\n",
              "      <td>0.900000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>0.018384</td>\n",
              "      <td>0.563870</td>\n",
              "      <td>0.898058</td>\n",
              "      <td>0.881585</td>\n",
              "      <td>0.889015</td>\n",
              "      <td>0.899123</td>\n",
              "      <td>0.895197</td>\n",
              "      <td>0.897155</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>0.009771</td>\n",
              "      <td>0.577029</td>\n",
              "      <td>0.900951</td>\n",
              "      <td>0.881585</td>\n",
              "      <td>0.890401</td>\n",
              "      <td>0.903084</td>\n",
              "      <td>0.895197</td>\n",
              "      <td>0.899123</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TrainOutput(global_step=4310, training_loss=0.11698136163698272, metrics={'train_runtime': 395.7109, 'train_samples_per_second': 174.066, 'train_steps_per_second': 10.892, 'total_flos': 0.0, 'train_loss': 0.11698136163698272, 'epoch': 10.0})"
            ]
          },
          "metadata": {},
          "execution_count": 50
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Predict on validation set\n",
        "outputs = trainer.predict(val_dataset)\n",
        "preds = (outputs.predictions > 0).astype(int)\n",
        "print(classification_report(y_val, preds, target_names=target_labels))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 349
        },
        "id": "R9iUNgRYp4Wd",
        "outputId": "747f53e4-b750-4208-e808-80c59ca48ff9"
      },
      "execution_count": 52,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": []
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "                      precision    recall  f1-score   support\n",
            "\n",
            "           Bedürftig       0.87      0.83      0.85        64\n",
            "          Wohlhabend       0.89      0.81      0.85        21\n",
            "              Mieter       0.95      0.93      0.94        60\n",
            "Vermieter/Eigentümer       0.94      0.92      0.93        84\n",
            "\n",
            "           micro avg       0.92      0.89      0.90       229\n",
            "           macro avg       0.91      0.87      0.89       229\n",
            "        weighted avg       0.92      0.89      0.90       229\n",
            "         samples avg       0.10      0.10      0.10       229\n",
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
            "/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
            "/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here we would save the model (disabled)"
      ],
      "metadata": {
        "id": "cLGjuOP45f1h"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# save_path = \"/content/drive/bert_multilabel_group_model_2\"\n",
        "# model.bert.save_pretrained(save_path)\n",
        "# tokenizer.save_pretrained(save_path)"
      ],
      "metadata": {
        "id": "wuUscjz9p7Cr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# trainer.save_model(save_path)"
      ],
      "metadata": {
        "id": "oorXbYxkp8mw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 3. Model-based prediction on the full corpus of party and media data"
      ],
      "metadata": {
        "id": "CZRhx_ytrKds"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BertPreTrainedModel, BertModel\n",
        "import torch.nn as nn\n",
        "\n",
        "class BertForMultiLabelClassification(BertPreTrainedModel):\n",
        "    def __init__(self, config, pos_weights=None):\n",
        "        super().__init__(config)\n",
        "        self.num_labels = config.num_labels\n",
        "        self.bert = BertModel(config)\n",
        "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
        "        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
        "        self.pos_weights = pos_weights\n",
        "        self.init_weights()\n",
        "\n",
        "    def forward(self, input_ids=None, attention_mask=None, labels=None):\n",
        "        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
        "        pooled_output = outputs[1]\n",
        "        pooled_output = self.dropout(pooled_output)\n",
        "        logits = self.classifier(pooled_output)\n",
        "\n",
        "        loss = None\n",
        "        if labels is not None:\n",
        "            loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weights.to(logits.device) if self.pos_weights is not None else None)\n",
        "            loss = loss_fct(logits, labels)\n",
        "\n",
        "        return {'loss': loss, 'logits': logits} if loss is not None else {'logits': logits}"
      ],
      "metadata": {
        "id": "kj7YWUQNrNOd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ],
      "metadata": {
        "id": "vOFwIzPQrO9p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "# Make sure y_train is available in the current session\n",
        "# If not, reload it or recompute it from your train split\n",
        "\n",
        "label_counts = y_train.sum(axis=0)\n",
        "label_freqs = label_counts / y_train.shape[0]\n",
        "pos_weights = 1.0 / (label_freqs + 1e-6)  # epsilon for numerical safety\n",
        "pos_weights = torch.tensor(pos_weights, dtype=torch.float32).to(device)"
      ],
      "metadata": {
        "id": "nOKZoXq2rQk-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BertTokenizerFast, BertConfig\n",
        "\n",
        "# Set save path\n",
        "save_path = \"/content/drive/bert_multilabel_group_model_2\"\n",
        "\n",
        "# Load tokenizer\n",
        "tokenizer = BertTokenizerFast.from_pretrained(save_path)\n",
        "\n",
        "# Load config and make sure it matches the number of labels used during training\n",
        "config = BertConfig.from_pretrained(save_path)\n",
        "config.num_labels = 4  # Match with training\n",
        "\n",
        "# Load model using config and pos_weights\n",
        "model = BertForMultiLabelClassification.from_pretrained(save_path, config=config, pos_weights=pos_weights)\n",
        "model.to(device)"
      ],
      "metadata": {
        "id": "QX59-6QurScO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3.1 Predicting group appeal in partisan statements"
      ],
      "metadata": {
        "id": "c1f0AhAW7l8p"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "party_df = pd.read_feather('/content/drive/parties_df_topic_climate_labels.feather')  # adjust the path"
      ],
      "metadata": {
        "id": "D1OhSX98rU2I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pip install stanza"
      ],
      "metadata": {
        "id": "fC9RbmBUrXqY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import stanza\n",
        "\n",
        "# Download the German model\n",
        "stanza.download('de')"
      ],
      "metadata": {
        "id": "GhDK2D2wrZ2-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load German pipeline (tokenize only)\n",
        "nlp = stanza.Pipeline(lang='de', processors='tokenize', use_gpu=True)\n",
        "\n",
        "# Example usage on a DataFrame\n",
        "def split_sentences_stanza(text):\n",
        "    if isinstance(text, str) and text.strip():\n",
        "        doc = nlp(text)\n",
        "        return [sentence.text for sentence in doc.sentences]\n",
        "    else:\n",
        "        return []\n",
        "\n",
        "# Apply to your filtered media dataset\n",
        "party_df_filtered = party_df[party_df['predicted_label'] == 1].copy()\n",
        "party_df_filtered['sentences'] = party_df_filtered['text'].apply(split_sentences_stanza)"
      ],
      "metadata": {
        "id": "T6zxcxc0rb4P"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "from transformers import BertTokenizerFast\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Define your label mapping\n",
        "label_map = ['Bedürftig', 'Wohlhabend', 'Mieter', 'Vermieter/Eigentümer']\n",
        "\n",
        "# Flatten all sentences from relevant articles\n",
        "party_df_filtered = party_df_filtered[party_df_filtered['predicted_label'] == 1].copy()\n",
        "party_df_filtered['sentences'] = party_df_filtered['text'].apply(split_sentences_stanza)\n",
        "\n",
        "all_sentences = []\n",
        "article_ids = []\n",
        "for idx, row in party_df_filtered.iterrows():\n",
        "    for sent in row['sentences']:\n",
        "        all_sentences.append(sent)\n",
        "        article_ids.append(idx)  # keep reference to original article\n",
        "\n",
        "# Tokenize\n",
        "tokenizer = BertTokenizerFast.from_pretrained(save_path)\n",
        "encodings = tokenizer(all_sentences, truncation=True, padding=True, return_tensors='pt')\n",
        "\n",
        "# Create Dataset\n",
        "class PredictionDataset(Dataset):\n",
        "    def __init__(self, encodings):\n",
        "        self.encodings = encodings\n",
        "    def __getitem__(self, idx):\n",
        "        return {key: val[idx] for key, val in self.encodings.items()}\n",
        "    def __len__(self):\n",
        "        return len(self.encodings['input_ids'])\n",
        "\n",
        "pred_dataset = PredictionDataset(encodings)\n",
        "pred_loader = DataLoader(pred_dataset, batch_size=32)\n",
        "\n",
        "# Move model to GPU\n",
        "model.eval()\n",
        "model.to(device)\n",
        "\n",
        "# Run predictions\n",
        "all_logits = []\n",
        "with torch.no_grad():\n",
        "    for batch in tqdm(pred_loader):\n",
        "        batch = {k: v.to(device) for k, v in batch.items() if k != 'token_type_ids'}\n",
        "        outputs = model(**batch)\n",
        "        logits = outputs['logits']\n",
        "        all_logits.append(logits.cpu())\n",
        "\n",
        "# Concatenate and apply sigmoid\n",
        "logits_tensor = torch.cat(all_logits, dim=0)\n",
        "probs = torch.sigmoid(logits_tensor)\n",
        "\n",
        "# Binary predictions\n",
        "threshold = 0.3\n",
        "predictions = (probs > threshold).int().tolist()\n",
        "\n",
        "# Construct result DataFrame\n",
        "pred_df = pd.DataFrame({\n",
        "    'article_id': article_ids,\n",
        "    'sentence': all_sentences,\n",
        "    'pred_labels': predictions\n",
        "})\n",
        "\n",
        "# Expand labels into text\n",
        "def get_label_names(binary_preds):\n",
        "    return [label_map[i] for i, v in enumerate(binary_preds) if v == 1]\n",
        "\n",
        "pred_df['label_names'] = pred_df['pred_labels'].apply(get_label_names)"
      ],
      "metadata": {
        "id": "gTIiCff1rhSX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Add probability columns\n",
        "for i, label in enumerate(label_map):\n",
        "    pred_df[f'prob_{label}'] = probs[:, i].numpy()"
      ],
      "metadata": {
        "id": "Ee5O-ofTrjQ6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here we would save the prediction (disabled)"
      ],
      "metadata": {
        "id": "ydsGqG9u5vAd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pred_df.reset_index(drop=True).to_feather('/content/drive/parties_df_topic_climate_and_group_labels.feather')"
      ],
      "metadata": {
        "id": "rBGmN3cXrluq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pred_df.to_excel('/content/drive/parties_df_group_labels.xlsx', index=True)"
      ],
      "metadata": {
        "id": "eVIGXAgHroQp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3.2 Model-based prediction of group appeal in media data"
      ],
      "metadata": {
        "id": "KBygQzwbsCDH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "media_df = pd.read_feather('/content/drive/media_df_topic_climate_labels.feather')  # adjust the path"
      ],
      "metadata": {
        "id": "t54KU9rCsEuu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install stanza\n",
        "import stanza\n",
        "# Load German pipeline (tokenize only)\n",
        "nlp = stanza.Pipeline(lang='de', processors='tokenize', use_gpu=True)\n",
        "\n",
        "# Example usage on a DataFrame\n",
        "def split_sentences_stanza(text):\n",
        "    if isinstance(text, str) and text.strip():\n",
        "        doc = nlp(text)\n",
        "        return [sentence.text for sentence in doc.sentences]\n",
        "    else:\n",
        "        return []\n",
        "\n",
        "# Apply to your filtered media dataset\n",
        "media_df_filtered = media_df[media_df['predicted_label'] == 1].copy()\n",
        "media_df_filtered['sentences'] = media_df_filtered['text'].apply(split_sentences_stanza)"
      ],
      "metadata": {
        "id": "x40mSrv3sGm6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "from transformers import BertTokenizerFast\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Define your label mapping\n",
        "label_map = ['Bedürftig', 'Wohlhabend', 'Mieter', 'Vermieter/Eigentümer']\n",
        "\n",
        "# Flatten all sentences from relevant articles\n",
        "media_df_filtered = media_df[media_df['predicted_label'] == 1].copy()\n",
        "media_df_filtered['sentences'] = media_df_filtered['text'].apply(split_sentences_stanza)\n",
        "\n",
        "all_sentences = []\n",
        "article_ids = []\n",
        "for idx, row in media_df_filtered.iterrows():\n",
        "    for sent in row['sentences']:\n",
        "        all_sentences.append(sent)\n",
        "        article_ids.append(idx)  # keep reference to original article\n",
        "\n",
        "# Tokenize\n",
        "tokenizer = BertTokenizerFast.from_pretrained(save_path)\n",
        "encodings = tokenizer(all_sentences, truncation=True, padding=True, return_tensors='pt')\n",
        "\n",
        "# Create Dataset\n",
        "class PredictionDataset(Dataset):\n",
        "    def __init__(self, encodings):\n",
        "        self.encodings = encodings\n",
        "    def __getitem__(self, idx):\n",
        "        return {key: val[idx] for key, val in self.encodings.items()}\n",
        "    def __len__(self):\n",
        "        return len(self.encodings['input_ids'])\n",
        "\n",
        "pred_dataset = PredictionDataset(encodings)\n",
        "pred_loader = DataLoader(pred_dataset, batch_size=32)\n",
        "\n",
        "# Move model to GPU\n",
        "model.eval()\n",
        "model.to(device)\n",
        "\n",
        "# Run predictions\n",
        "all_logits = []\n",
        "with torch.no_grad():\n",
        "    for batch in tqdm(pred_loader):\n",
        "        batch = {k: v.to(device) for k, v in batch.items() if k != 'token_type_ids'}\n",
        "        outputs = model(**batch)\n",
        "        logits = outputs['logits']\n",
        "        all_logits.append(logits.cpu())\n",
        "\n",
        "# Concatenate and apply sigmoid\n",
        "logits_tensor = torch.cat(all_logits, dim=0)\n",
        "probs = torch.sigmoid(logits_tensor)\n",
        "\n",
        "# Binary predictions\n",
        "threshold = 0.3\n",
        "predictions = (probs > threshold).int().tolist()\n",
        "\n",
        "# Construct result DataFrame\n",
        "pred_df_media = pd.DataFrame({\n",
        "    'article_id': article_ids,\n",
        "    'sentence': all_sentences,\n",
        "    'pred_labels': predictions\n",
        "})\n",
        "\n",
        "# Expand labels into text\n",
        "def get_label_names(binary_preds):\n",
        "    return [label_map[i] for i, v in enumerate(binary_preds) if v == 1]\n",
        "\n",
        "pred_df_media['label_names'] = pred_df_media['pred_labels'].apply(get_label_names)"
      ],
      "metadata": {
        "id": "IKkotH8IsJIw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Add probability columns\n",
        "for i, label in enumerate(label_map):\n",
        "    pred_df_media[f'prob_{label}'] = probs[:, i].numpy()"
      ],
      "metadata": {
        "id": "N2VOMhA2sLFK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pred_df_media.to_excel('/content/drive/media_df_group_label.xlsx', index=False)"
      ],
      "metadata": {
        "id": "lCP_BbZ5sM3t"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Reset index to ensure clean mapping from article_id\n",
        "media_df_filtered = media_df[media_df['predicted_label'] == 1].copy().reset_index()  # adds old index as 'index'\n",
        "media_df_filtered.rename(columns={'index': 'article_id'}, inplace=True)"
      ],
      "metadata": {
        "id": "uDD9vSuVsPEo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Select article_id and newspaper for merging\n",
        "meta_cols = ['article_id', 'newspaper']  # Add more columns if needed\n",
        "meta_df = media_df_filtered[meta_cols]\n",
        "\n",
        "# Merge predicted sentences back with metadata\n",
        "pred_df_media = pred_df_media.merge(meta_df, on='article_id', how='left')"
      ],
      "metadata": {
        "id": "5L_VHpqosQvN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here we would save the prediction (disabled)"
      ],
      "metadata": {
        "id": "bH8yS41S52ZC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# pred_df_media.to_excel('/content/drive/media_df_group_label_n.xlsx', index=False)"
      ],
      "metadata": {
        "id": "KN8RE3Q8sSXw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 4. Final revision\n",
        "\n",
        "We ultimately revised the final predictions by screening the positive predictions."
      ],
      "metadata": {
        "id": "ofTs1_F2tVtV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "- Final media data: media_ml_predictions_with_metadata_revised.xlsx\n",
        "\n",
        "- Final partisan data: ml_predictions_metadata.xlsx"
      ],
      "metadata": {
        "id": "aJZu9RJCtfBp"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "These final data have been used in the file \"Reproduction of figures\" to generate the figures used in the Technical Report"
      ],
      "metadata": {
        "id": "3XB3vxS48SUV"
      }
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "2WtLqCYrteV1"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}